
Conversation

@penelopeysm (Member) commented Nov 11, 2025

This reimplements FastLDF, conceptually in the same way as #1113. Please see that PR for the bulk of the explanation. The difference is that this also unifies the implementation of FastLDF and InitFromParams, such that FastLDF is now actually just InitFromParams but backed by the combination of vector + ranges.
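
As a rough illustration of the vector + ranges idea (a minimal sketch, not DynamicPPL's actual internals):

params = [0.3, 1.2, -0.5, 0.7]         # flat parameter vector
ranges = (mu=1:1, tau=2:2, theta=3:4)  # hypothetical per-variable ranges

# Reading a variable's value is then a cheap indexed view into the flat
# vector, rather than a dictionary lookup keyed on VarNames.
getparam(name::Symbol) = view(params, ranges[name])

getparam(:theta)   # 2-element view: [-0.5, 0.7]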

Here's a slightly modified diagram from my slides yesterday:

Diagram showing how InitFromParams and FastLDF are related

Other speedups

Note that this unification also means that other initialisation strategies, i.e. InitFromPrior, InitFromUniform, and other forms of InitFromParams, can also benefit from the speedup (as shown in the top half of the diagram above). This was essentially done in #1125 but lumped into this PR as well. See that PR for benchmarks.
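
For illustration, this is roughly how the different strategies are driven through the shared path (hedged sketch: the exact `init!!` argument order and the NamedTuple form of `InitFromParams` are my assumptions here, not guaranteed API):

using DynamicPPL, Distributions, Random

@model demo() = x ~ Normal()
model = demo()

# All three strategies now go through the same fast evaluation machinery.
for strategy in (InitFromPrior(), InitFromUniform(), InitFromParams((; x=0.5)))
    _, vi = DynamicPPL.init!!(Random.default_rng(), model, VarInfo(model), strategy)
    @show getlogjoint(vi)
end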

Does this still need to be Experimental?

I'd suggest yes for this PR, if only to prove correctness compared to the old LDF. Making this replace the old LDF should be a fairly trivial follow-up. I'm open to other ideas.

Does this need to be breaking?

Yes, because the expected return value of DynamicPPL.init has changed. Technically, init wasn't exported, but AbstractInitStrategy was, so init was effectively public (it should have been exported).

On top of that, this PR relies on changes in #1133, which are also breaking.

Benchmarks

Performance characteristics are exactly the same as in the original PR #1113. Benchmarks were run on Julia 1.11.7 with 1 thread.

Benchmarking code
using DynamicPPL, Distributions, LogDensityProblems, Chairmarks, LinearAlgebra
using ADTypes, ForwardDiff, ReverseDiff
@static if VERSION < v"1.12"
    using Enzyme, Mooncake
end

const adtypes = @static if VERSION < v"1.12"
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
        ("MC", AutoMooncake()),
        ("EN" => AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const))
    ]
else
    [
        ("FD", AutoForwardDiff()),
        ("RD", AutoReverseDiff()),
    ]
end

function benchmark_ldfs(model; skip=Union{})
    vi = VarInfo(model)
    x = vi[:]
    ldf_no = DynamicPPL.LogDensityFunction(model, getlogjoint, vi)
    fldf_no = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)
    @assert LogDensityProblems.logdensity(ldf_no, x) ≈ LogDensityProblems.logdensity(fldf_no, x)
    median_old = median(@be LogDensityProblems.logdensity(ldf_no, x))
    print("LogDensityFunction: eval      ----  ")
    display(median_old)
    median_new = median(@be LogDensityProblems.logdensity(fldf_no, x))
    print("           FastLDF: eval      ----  ")
    display(median_new)
    println("                  speedup     ----  ", median_old.time / median_new.time)
    for (name, adtype) in adtypes
        adtype isa skip && continue
        ldf = DynamicPPL.LogDensityFunction(model, getlogjoint, vi; adtype=adtype)
        fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi; adtype=adtype)
        ldf_grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
        fldf_grad = LogDensityProblems.logdensity_and_gradient(fldf, x)
        @assert ldf_grad[2] ≈ fldf_grad[2]
        median_old = median(@be LogDensityProblems.logdensity_and_gradient(ldf, x))
        print("LogDensityFunction: grad ($name) ----  ")
        display(median_old)
        median_new = median(@be LogDensityProblems.logdensity_and_gradient(fldf, x))
        print("           FastLDF: grad ($name) ----  ")
        display(median_new)
        println("                 speedup ($name) ----  ", median_old.time / median_new.time)
    end
end

@model f() = x ~ Normal()
benchmark_ldfs(f())

y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(y)), tau^2 * I)
    for i in eachindex(y)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
benchmark_ldfs(eight_schools(y, sigma))

@model function badvarnames()
    N = 20
    x = Vector{Float64}(undef, N)
    for i in 1:N
        x[i] ~ Normal()
    end
end
benchmark_ldfs(badvarnames())

@model function inner()
    m ~ Normal(0, 1)
    s ~ Exponential()
    return (m=m, s=s)
end
@model function withsubmodel()
    params ~ to_submodel(inner())
    y ~ Normal(params.m, params.s)
    1.0 ~ Normal(y)
end
benchmark_ldfs(withsubmodel())

Trivial model

julia> benchmark_ldfs(f())
LogDensityFunction: eval      ----  170.964 ns (6 allocs: 192 bytes)
           FastLDF: eval      ----  10.944 ns
                  speedup     ----  15.621700990952178
LogDensityFunction: grad (FD) ----  317.935 ns (13 allocs: 496 bytes)
           FastLDF: grad (FD) ----  54.127 ns (3 allocs: 96 bytes)
                 speedup (FD) ----  5.873855538906422
LogDensityFunction: grad (RD) ----  4.250 μs (82 allocs: 3.062 KiB)
           FastLDF: grad (RD) ----  3.019 μs (46 allocs: 1.562 KiB)
                 speedup (RD) ----  1.4079581845621525
LogDensityFunction: grad (MC) ----  1.100 μs (25 allocs: 1.219 KiB)
           FastLDF: grad (MC) ----  338.481 ns (4 allocs: 192 bytes)
                 speedup (MC) ----  3.250793303424883
LogDensityFunction: grad (EN) ----  432.455 ns (16 allocs: 560 bytes)
           FastLDF: grad (EN) ----  128.409 ns (2 allocs: 64 bytes)
                 speedup (EN) ----  3.3677876106194686

Eight-schools centred

LogDensityFunction: eval      ----  877.594 ns (21 allocs: 1.344 KiB)
           FastLDF: eval      ----  209.184 ns (4 allocs: 256 bytes)
                  speedup     ----  4.195326219512196
LogDensityFunction: grad (FD) ----  1.611 μs (28 allocs: 5.484 KiB)
           FastLDF: grad (FD) ----  672.465 ns (11 allocs: 2.594 KiB)
                 speedup (FD) ----  2.3956633005948262
LogDensityFunction: grad (RD) ----  40.209 μs (614 allocs: 25.562 KiB)
           FastLDF: grad (RD) ----  38.708 μs (562 allocs: 20.562 KiB)
                 speedup (RD) ----  1.03877751369226
LogDensityFunction: grad (MC) ----  4.528 μs (64 allocs: 4.016 KiB)
           FastLDF: grad (MC) ----  1.183 μs (12 allocs: 784 bytes)
                 speedup (MC) ----  3.8262402956653028
LogDensityFunction: grad (EN) ----  1.858 μs (44 allocs: 2.609 KiB)
           FastLDF: grad (EN) ----  739.026 ns (13 allocs: 832 bytes)
                 speedup (EN) ----  2.5145699058742537

Lots of IndexLenses

LogDensityFunction: eval      ----  1.448 μs (46 allocs: 1.906 KiB)
           FastLDF: eval      ----  459.641 ns (2 allocs: 224 bytes)
                  speedup     ----  3.150069687595608
LogDensityFunction: grad (FD) ----  4.535 μs (103 allocs: 14.266 KiB)
           FastLDF: grad (FD) ----  2.697 μs (11 allocs: 4.281 KiB)
                 speedup (FD) ----  1.6813743665801506
LogDensityFunction: grad (RD) ----  59.584 μs (1076 allocs: 38.828 KiB)
           FastLDF: grad (RD) ----  51.209 μs (773 allocs: 27.438 KiB)
                 speedup (RD) ----  1.1635454705227597
LogDensityFunction: grad (MC) ----  6.656 μs (160 allocs: 7.000 KiB)
           FastLDF: grad (MC) ----  2.229 μs (28 allocs: 1.094 KiB)
                 speedup (MC) ----  2.985981308411215
LogDensityFunction: grad (EN) ----  3.271 μs (64 allocs: 6.141 KiB)
           FastLDF: grad (EN) ----  1.608 μs (5 allocs: 2.188 KiB)
                 speedup (EN) ----  2.0343495042622473

Submodel

LogDensityFunction: eval      ----  867.424 ns (20 allocs: 1.234 KiB)
           FastLDF: eval      ----  103.168 ns
                  speedup     ----  8.407896391298882
LogDensityFunction: grad (FD) ----  1.175 μs (27 allocs: 2.219 KiB)
           FastLDF: grad (FD) ----  187.776 ns (3 allocs: 112 bytes)
                 speedup (FD) ----  6.25744516852358
LogDensityFunction: grad (RD) ----  13.959 μs (221 allocs: 9.266 KiB)
           FastLDF: grad (RD) ----  10.896 μs (148 allocs: 5.188 KiB)
                 speedup (RD) ----  1.2811252351888394
LogDensityFunction: grad (MC) ----  5.750 μs (72 allocs: 3.312 KiB)
           FastLDF: grad (MC) ----  599.667 ns (6 allocs: 240 bytes)
                 speedup (MC) ----  9.588660366870483
LogDensityFunction: grad (EN) ----  2.432 μs (52 allocs: 2.500 KiB)
           FastLDF: grad (EN) ----  341.659 ns (2 allocs: 80 bytes)
                 speedup (EN) ----  7.117680019783942

MCMC

using Turing, Random, LinearAlgebra
y = [28, 8, -3, 7, -1, 1, 18, 12]
sigma = [15, 10, 16, 11, 9, 11, 10, 18]
J = 8
@model function eight_schools(y, sigma)
    mu ~ Normal(0, 5)
    tau ~ truncated(Cauchy(0, 5); lower=0)
    theta ~ MvNormal(fill(mu, length(sigma)), tau^2 * I)
    for i in eachindex(sigma)
        y[i] ~ Normal(theta[i], sigma[i])
    end
    return (mu=mu, tau=tau)
end
model = eight_schools(y, sigma);

using Enzyme, ADTypes
adtype = AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const)
@time sample(model, NUTS(; adtype=adtype), 1000; nadapts=10000, thinning=10, progress=false, verbose=false);

With this PR, that call is down from around 8.8 seconds to 1.7 seconds.

@github-actions bot (Contributor) commented Nov 11, 2025

Benchmark Report for Commit 63198fd

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.7 │             1.7 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          741.8 │            43.2 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          434.0 │            56.2 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          830.6 │            40.2 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         6668.5 │            25.9 │
│           Smorgasbord │   201 │ forwarddiff │      typed_vector │   true │          766.5 │            41.9 │
│           Smorgasbord │   201 │ forwarddiff │    untyped_vector │   true │          818.9 │            36.4 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          903.3 │            47.9 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          733.0 │             5.8 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          906.8 │             4.1 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         4000.4 │             5.8 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │         1050.9 │             8.8 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        44154.6 │             5.3 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         9035.4 │             9.8 │
│               Dynamic │    10 │    mooncake │             typed │   true │          124.0 │            11.8 │
│              Submodel │     1 │    mooncake │             typed │   true │            8.4 │             6.7 │
│                   LDA │    12 │ reversediff │             typed │   true │         1003.8 │             2.0 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

@github-actions bot (Contributor) commented:

DynamicPPL.jl documentation for PR #1132 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1132/

@penelopeysm force-pushed the py/ldf branch 2 times, most recently from 248f374 to a4c71e6 on November 11, 2025 12:33
@penelopeysm changed the base branch from main to breaking on November 11, 2025 12:33
@codecov bot commented Nov 11, 2025

Codecov Report

❌ Patch coverage is 94.28571% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.73%. Comparing base (4ca9528) to head (63198fd).
⚠️ Report is 1 commit behind head on breaking.

Files with missing lines          Patch %   Lines
src/onlyaccs.jl                   75.00%    3 Missing ⚠️
src/compiler.jl                   85.71%    2 Missing ⚠️
ext/DynamicPPLEnzymeCoreExt.jl     0.00%    1 Missing ⚠️
src/fasteval.jl                   98.61%    1 Missing ⚠️
src/utils.jl                      66.66%    1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff              @@
##           breaking    #1132      +/-   ##
============================================
+ Coverage     81.32%   81.73%   +0.41%     
============================================
  Files            40       42       +2     
  Lines          3807     3921     +114     
============================================
+ Hits           3096     3205     +109     
- Misses          711      716       +5     



# Test that various different ways of specifying array types as arguments work with all
# ADTypes.
@testset "Array argument types" begin
@penelopeysm (Member, Author):

This testset is duplicated from LogDensityFunction, so nothing new here.

Comment on lines +39 to +40
The same problem precludes us from eventually broadening the scope of DynamicPPL.jl to
support distributions with non-numeric samples.
@penelopeysm (Member, Author):

I ran into this issue before: #1003 (comment). The problem is that, because we no longer read values from the VarInfo, we don't have information about the values' eltype, which can lead to the sort of problems observed in this docstring. (To be fair, previously we would set values into the varinfo and then read eltype(varinfo), which is a terrible idea because parameter types need not be homogeneous. Although this seems like more faff, it's probably for the better.)
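
To illustrate (my own sketch, not code from this PR): under AD the parameter vector that reaches the evaluator has a Dual eltype, which a Float64-typed VarInfo knows nothing about, so the eltype has to come from the parameters themselves:

using ForwardDiff

# Sketch: derive the working eltype from the incoming parameter vector.
paramtype(params::AbstractVector) = eltype(params)

paramtype([1.0, 2.0])                  # Float64 on a plain evaluation
ForwardDiff.gradient([1.0, 2.0]) do x
    # Inside AD, x is a Vector of Duals; Float64-typed storage would error here.
    @assert paramtype(x) <: ForwardDiff.Dual
    sum(abs2, x)
end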

Comment on lines +223 to +234
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
# it _should_ do, but this is wrong regardless.
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
vi = if Threads.nthreads() > 1
    accs = map(
        acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
        accs,
    )
    ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
else
    OnlyAccsVarInfo(accs)
end

Comment on lines +18 to +19
"""
typed_identity(x)
@penelopeysm (Member, Author):

I think eventually this should go to Bijectors, but we can keep it here for now.


@mhauru (Member) left a comment:

Love it. Just some small bits here and there.

Happy to put this in Experimental, but I wouldn't be offended by immediately or very soon trying to take over LogDensityFunction and seeing what happens, especially in Turing.jl's test suite.

src/fasteval.jl (outdated)
- `fastldf.adtype`: The AD type used for gradient calculations, or `nothing` if no AD
type was provided.
## Extended help
@mhauru (Member):

This is all very useful documentation, and for now this is a good place. Once this comes out of Experimental I wonder if some of the extended stuff should go somewhere else. Maybe either in developer docs, or in a comment.

@penelopeysm (Member, Author):

Yes, definitely, somewhere in the developer docs. I'm a bit annoyed that I wrote up the models-and-varinfo evaluation stuff only to now tear it up (although, to be fair, writing that was what prompted me to think about this, so we'll take the win).

@penelopeysm penelopeysm requested a review from mhauru November 12, 2025 16:14
@mhauru (Member) left a comment:

Great, thanks!

@penelopeysm penelopeysm merged commit 535ce4f into breaking Nov 13, 2025
19 checks passed
@penelopeysm penelopeysm deleted the py/ldf branch November 13, 2025 13:30
github-merge-queue bot pushed a commit that referenced this pull request Dec 2, 2025
* v0.39

* Update DPPL compats for benchmarks and docs

* remove merge conflict markers

* Remove `NodeTrait` (#1133)

* Remove NodeTrait

* Changelog

* Fix exports

* docs

* fix a bug

* Fix doctests

* Fix test

* tweak changelog

* FastLDF / InitContext unified (#1132)

* Fast Log Density Function

* Make it work with AD

* Optimise performance for identity VarNames

* Mark `get_range_and_linked` as having zero derivative

* Update comment

* make AD testing / benchmarking use FastLDF

* Fix tests

* Optimise away `make_evaluate_args_and_kwargs`

* const func annotation

* Disable benchmarks on non-typed-Metadata-VarInfo

* Fix `_evaluate!!` correctly to handle submodels

* Actually fix submodel evaluate

* Document thoroughly and organise code

* Support more VarInfos, make it thread-safe (?)

* fix bug in parsing ranges from metadata/VNV

* Fix get_param_eltype for TSVI

* Disable Enzyme benchmark

* Don't override _evaluate!!, that breaks ForwardDiff (sometimes)

* Move FastLDF to experimental for now

* Fix imports, add tests, etc

* More test fixes

* Fix imports / tests

* Remove AbstractFastEvalContext

* Changelog and patch bump

* Add correctness tests, fix imports

* Concretise parameter vector in tests

* Add zero-allocation tests

* Add Chairmarks as test dep

* Disable allocations tests on multi-threaded

* Fast InitContext (#1125)

* Make InitContext work with OnlyAccsVarInfo

* Do not convert NamedTuple to Dict

* remove logging

* Enable InitFromPrior and InitFromUniform too

* Fix `infer_nested_eltype` invocation

* Refactor FastLDF to use InitContext

* note init breaking change

* fix logjac sign

* workaround Mooncake segfault

* fix changelog too

* Fix get_param_eltype for context stacks

* Add a test for threaded observe

* Export init

* Remove dead code

* fix transforms for pathological distributions

* Tidy up loads of things

* fix typed_identity spelling

* fix definition order

* Improve docstrings

* Remove stray comment

* export get_param_eltype (unfortunately)

* Add more comment

* Update comment

* Remove inlines, fix OAVI docstring

* Improve docstrings

* Simplify InitFromParams constructor

* Replace map(identity, x[:]) with [i for i in x[:]]

* Simplify implementation for InitContext/OAVI

* Add another model to allocation tests

Co-authored-by: Markus Hauru <[email protected]>

* Revert removal of dist argument (oops)

* Format

* Update some outdated bits of FastLDF docstring

* remove underscores

---------

Co-authored-by: Markus Hauru <[email protected]>

* implement `LogDensityProblems.dimension`

* forgot about capabilities...

* use interpolation in run_ad

* Improvements to benchmark outputs (#1146)

* print output

* fix

* reenable

* add more lines to guide the eye

* reorder table

* print tgrad / trel as well

* forgot this type

* Allow generation of `ParamsWithStats` from `FastLDF` plus parameters, and also `bundle_samples` (#1129)

* Implement `ParamsWithStats` for `FastLDF`

* Add comments

* Implement `bundle_samples` for ParamsWithStats -> MCMCChains

* Remove redundant comment

* don't need Statistics?

* Make FastLDF the default (#1139)

* Make FastLDF the default

* Add miscellaneous LogDensityProblems tests

* Use `init!!` instead of `fast_evaluate!!`

* Rename files, rebalance tests

* Implement `predict`, `returned`, `logjoint`, ... with `OnlyAccsVarInfo` (#1130)

* Use OnlyAccsVarInfo for many re-evaluation functions

* drop `fast_` prefix

* Add a changelog

* Improve FastLDF type stability when all parameters are linked or unlinked (#1141)

* Improve type stability when all parameters are linked or unlinked

* fix a merge conflict

* fix enzyme gc crash (locally at least)

* Fixes from review

* Make threadsafe evaluation opt-in (#1151)

* Make threadsafe evaluation opt-in

* Reduce number of type parameters in methods

* Make `warned_warn_about_threads_threads_threads_threads` shorter

* Improve `setthreadsafe` docstring

* warn on bare `@threads` as well

* fix merge

* Fix performance issues

* Use maxthreadid() in TSVI

* Move convert_eltype code to threadsafe eval function

* Point to new Turing docs page

* Add a test for setthreadsafe

* Tidy up check_model

* Apply suggestions from code review

Fix outdated docstrings

Co-authored-by: Markus Hauru <[email protected]>

* Improve warning message

* Export `requires_threadsafe`

* Add an actual docstring for `requires_threadsafe`

---------

Co-authored-by: Markus Hauru <[email protected]>

* Standardise `:lp` -> `:logjoint` (#1161)

* Standardise `:lp` -> `:logjoint`

* changelog

* fix a test

---------

Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
